import taichi as ti

from ..shape import Shape
from ..utils import Vec3, Inf, Ray3
from ..config import EPS_CYLINDER


class Cylinder(Shape):
    """Open cylindrical surface: the side wall only, without top/bottom caps.

    The wall surrounds the y axis (x = z = 0) of the ray's coordinate frame,
    has radius ``r`` and spans heights ``0 < y < h``. Both values are stored
    in this shape's slot of ``Shape.shapes``: ``attribute[0]`` is the radius
    and ``attribute[1]`` is the height.
    """

    def __init__(self, r: float, h: float = 1):
        super().__init__()
        # Persist the geometry in the shared shape table so the Taichi-side
        # ``intersect`` can read it back by shape index.
        params = Shape.shapes[self.index].attribute
        params[0] = r  # radius
        params[1] = h  # height (defaults to 1)

    @ti.func
    def intersect(self: int, ray: Ray3, closest: float):
        # Intersect ``ray`` with the cylinder wall stored at shape slot ``self``
        # (here ``self`` is the integer index into ``Shape.shapes``).
        #
        # The ray is projected onto the xz plane and |o + t*d|^2 = r^2 is
        # solved with the half-b quadratic formula. The nearer root is
        # preferred; the farther one is used only if the nearer is rejected.
        # A root is accepted when it lies in [EPS_CYLINDER, closest] and the
        # hit point's height satisfies 0 < y < h.
        #
        # Returns (t, normal): the hit distance (Inf on a miss) and the
        # outward radial unit normal at the hit (Vec3(0) on a miss).
        radius = Shape.shapes[self].attribute[0]
        hit_t = Inf
        hit_normal = Vec3(0)

        origin_xz = ray.origin.xz
        dir_xz = ray.direct.xz
        qa = dir_xz.norm_sqr()
        half_b = -origin_xz.dot(dir_xz)
        qc = origin_xz.norm_sqr() - radius**2
        disc = half_b * half_b - qc * qa

        # Tangent or near-tangent rays count as misses. So do rays parallel
        # to the axis: there qa == 0, which forces disc == 0.
        if disc > EPS_CYLINDER:
            height = Shape.shapes[self].attribute[1]
            root_disc = disc**0.5
            t_near = (half_b - root_disc) / qa
            t_far = (half_b + root_disc) / qa
            if EPS_CYLINDER <= t_near <= closest and 0 < ray.atY(t_near) < height:
                hit_t = t_near
            elif EPS_CYLINDER <= t_far <= closest and 0 < ray.atY(t_far) < height:
                hit_t = t_far

            if hit_t < Inf:
                hit_point = ray.at(hit_t)
                # On the wall |(x, z)| == radius, so dividing by the radius
                # yields a unit vector pointing away from the axis.
                hit_normal = Vec3(hit_point.x, 0, hit_point.z) / radius

        return hit_t, hit_normal
